{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Checkout my [Twitter(@rohanpaul_ai)](https://twitter.com/rohanpaul_ai) for daily LLM bits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Falcon finetuning on openassistant-guanaco\n",
    "\n",
    "# [Link to my Youtube Video Explaining this whole Notebook](https://www.youtube.com/watch?v=fEzuBFi35J4&list=PLxqBkZuBynVTzqUQCQFgetR97y1X_1uCI&index=11&ab_channel=Rohan-Paul-AI)\n",
    "\n",
    "[![Imgur](https://imgur.com/DGiAiTI.png)](https://www.youtube.com/watch?v=fEzuBFi35J4&list=PLxqBkZuBynVTzqUQCQFgetR97y1X_1uCI&index=11&ab_channel=Rohan-Paul-AI)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from datasets import load_dataset\n",
    "from peft import LoraConfig\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    BitsAndBytesConfig,\n",
    "    HfArgumentParser,\n",
    "    TrainingArguments,\n",
    ")\n",
    "from peft.tuners.lora import LoraLayer\n",
    "from trl import SFTTrainer\n",
    "from dataclasses import dataclass, field\n",
    "from typing import Optional\n",
    "\n",
    "@dataclass\n",
    "class ModelArguments:\n",
    "    \"\"\"\n",
    "    Arguments for creating and preparing the model: base checkpoint,\n",
    "    4-bit (bitsandbytes) quantization settings, and LoRA hyperparameters.\n",
    "    \"\"\"\n",
    "    model_name: str = field(\n",
    "        default=\"tiiuae/falcon-7b\",\n",
    "        metadata={\"help\": \"The model name or path from the Hugging Face hub.\"},\n",
    "    )\n",
    "    use_4bit: bool = field(\n",
    "        default=True,\n",
    "        metadata={\"help\": \"Activate 4bit precision base model loading\"},\n",
    "    )\n",
    "    use_nested_quant: bool = field(\n",
    "        default=False,\n",
    "        metadata={\"help\": \"Activate nested quantization for 4bit base models\"},\n",
    "    )\n",
    "    bnb_4bit_compute_dtype: str = field(\n",
    "        default=\"float16\",\n",
    "        metadata={\"help\": \"Compute dtype for 4bit base models\"},\n",
    "    )\n",
    "    bnb_4bit_quant_type: str = field(\n",
    "        default=\"nf4\",\n",
    "        metadata={\"help\": \"Quantization type: fp4 or nf4\"},\n",
    "    )\n",
    "    # LoRA hyperparameters, forwarded to peft.LoraConfig.\n",
    "    lora_alpha: int = field(\n",
    "        default=16,\n",
    "        metadata={\"help\": \"Scaling factor for the LoRA update.\"},\n",
    "    )\n",
    "    lora_dropout: float = field(\n",
    "        default=0.1,\n",
    "        metadata={\"help\": \"Dropout probability applied to the LoRA layers.\"},\n",
    "    )\n",
    "    lora_r: int = field(\n",
    "        default=64,\n",
    "        metadata={\"help\": \"Rank of the LoRA update matrices.\"},\n",
    "    )\n",
    "\n",
    "@dataclass\n",
    "class ScriptArguments:\n",
    "    \"\"\"\n",
    "    Arguments for model training and data handling.\n",
    "    \"\"\"\n",
    "    local_rank: int = field(default=-1, metadata={\"help\": \"Used for multi-gpu\"})\n",
    "    per_device_train_batch_size: int = field(default=4, metadata={\"help\": \"Train batch size per device.\"})\n",
    "    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={\"help\": \"Eval batch size per device.\"})\n",
    "    gradient_accumulation_steps: Optional[int] = field(default=4, metadata={\"help\": \"Steps to accumulate before an optimizer update.\"})\n",
    "    learning_rate: Optional[float] = field(default=2e-4, metadata={\"help\": \"Initial learning rate.\"})\n",
    "    max_grad_norm: Optional[float] = field(default=0.3, metadata={\"help\": \"Gradient clipping norm.\"})\n",
    "    # Annotation fixed from Optional[int]: the default (0.001) is a float.\n",
    "    weight_decay: Optional[float] = field(default=0.001, metadata={\"help\": \"Weight decay applied by the optimizer.\"})\n",
    "    max_seq_length: Optional[int] = field(default=512, metadata={\"help\": \"Maximum tokenized sequence length.\"})\n",
    "    dataset_name: Optional[str] = field(\n",
    "        default=\"timdettmers/openassistant-guanaco\",\n",
    "        metadata={\"help\": \"The preference dataset to use.\"},\n",
    "    )\n",
    "    num_train_epochs: Optional[int] = field(\n",
    "        default=1,\n",
    "        metadata={\"help\": \"The number of training epochs for the reward model.\"},\n",
    "    )\n",
    "    fp16: Optional[bool] = field(\n",
    "        default=False,\n",
    "        metadata={\"help\": \"Enables fp16 training.\"},\n",
    "    )\n",
    "    bf16: Optional[bool] = field(\n",
    "        default=False,\n",
    "        metadata={\"help\": \"Enables bf16 training.\"},\n",
    "    )\n",
    "    packing: Optional[bool] = field(\n",
    "        default=False,\n",
    "        metadata={\"help\": \"Use packing dataset creating.\"},\n",
    "    )\n",
    "    gradient_checkpointing: Optional[bool] = field(\n",
    "        default=True,\n",
    "        metadata={\"help\": \"Enables gradient checkpointing.\"},\n",
    "    )\n",
    "    optim: Optional[str] = field(\n",
    "        default=\"paged_adamw_32bit\",\n",
    "        metadata={\"help\": \"The optimizer to use.\"},\n",
    "    )\n",
    "    lr_scheduler_type: str = field(\n",
    "        default=\"constant\",\n",
    "        metadata={\"help\": \"Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis\"},\n",
    "    )\n",
    "    max_steps: int = field(default=10000, metadata={\"help\": \"How many optimizer update steps to take\"})\n",
    "    warmup_ratio: float = field(default=0.03, metadata={\"help\": \"Fraction of steps to do a warmup for\"})\n",
    "    group_by_length: bool = field(\n",
    "        default=True,\n",
    "        metadata={\n",
    "            \"help\": \"Group sequences into batches with same length. Saves memory and speeds up training considerably.\"\n",
    "        },\n",
    "    )\n",
    "    save_steps: int = field(default=10, metadata={\"help\": \"Save checkpoint every X updates steps.\"})\n",
    "    logging_steps: int = field(default=10, metadata={\"help\": \"Log every X updates steps.\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_peftconfig_tokenizer(args: ModelArguments):\n",
    "    \"\"\"\n",
    "    Create the model, tokenizer, and peft_config based on provided arguments.\n",
    "\n",
    "    Args:\n",
    "        args (ModelArguments): Checkpoint name, 4-bit quantization settings,\n",
    "            and LoRA hyperparameters.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (model, peft_config, tokenizer) — the quantized base model,\n",
    "        the LoraConfig to hand to the trainer, and the matching tokenizer.\n",
    "    \"\"\"\n",
    "    # Resolve the compute-dtype string (e.g. \"float16\") to a torch dtype.\n",
    "    compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)\n",
    "\n",
    "    # Configure BitsAndBytes for model quantization\n",
    "    bnb_config = BitsAndBytesConfig(\n",
    "        load_in_4bit=args.use_4bit,\n",
    "        bnb_4bit_quant_type=args.bnb_4bit_quant_type,\n",
    "        bnb_4bit_compute_dtype=compute_dtype,\n",
    "        bnb_4bit_use_double_quant=args.use_nested_quant,\n",
    "    )\n",
    "\n",
    "    # Alert for bfloat16 acceleration support: compute capability >= 8\n",
    "    # (Ampere or newer) supports bf16, so suggest --bf16 to the user.\n",
    "    if compute_dtype == torch.float16 and args.use_4bit:\n",
    "        major, _ = torch.cuda.get_device_capability()\n",
    "        if major >= 8:\n",
    "            print(\"=\" * 80)\n",
    "            print(\"Your GPU supports bfloat16, you can accelerate training with --bf16\")\n",
    "            print(\"=\" * 80)\n",
    "\n",
    "    # Load the model with quantization configuration.\n",
    "    # device_map={\\\"\\\": 0} places the entire model on GPU 0;\n",
    "    # trust_remote_code is required for Falcon's custom modeling code.\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        args.model_name, quantization_config=bnb_config, device_map={\"\": 0}, trust_remote_code=True\n",
    "    )\n",
    "\n",
    "    # Define Lora Configuration; target_modules lists the Falcon attention\n",
    "    # and MLP projection layers that receive LoRA adapters.\n",
    "    peft_config = LoraConfig(\n",
    "        lora_alpha=args.lora_alpha,\n",
    "        lora_dropout=args.lora_dropout,\n",
    "        r=args.lora_r,\n",
    "        bias=\"none\",\n",
    "        task_type=\"CAUSAL_LM\",\n",
    "        target_modules=[\n",
    "            \"query_key_value\",\n",
    "            \"dense\",\n",
    "            \"dense_h_to_4h\",\n",
    "            \"dense_4h_to_h\",\n",
    "        ],\n",
    "    )\n",
    "\n",
    "    # Load the tokenizer and set padding token\n",
    "    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)\n",
    "\n",
    "    # Need to do below for models like Falcon-7B, GPT-2 etc,\n",
    "    # because it doesn't have an official pad token.\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "    return model, peft_config, tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_arguments():\n",
    "    \"\"\"\n",
    "    Parse Model and Script Arguments.\n",
    "\n",
    "    In a Jupyter kernel, sys.argv carries the kernel's own flags (e.g.\n",
    "    ``-f .../kernel.json``); a strict parse would abort with SystemExit.\n",
    "    Unknown arguments are therefore ignored instead of raising.\n",
    "\n",
    "    Returns:\n",
    "        ModelArguments, ScriptArguments\n",
    "    \"\"\"\n",
    "    parser = HfArgumentParser((ModelArguments, ScriptArguments))\n",
    "    # return_remaining_strings=True uses parse_known_args semantics, so\n",
    "    # argv entries the parser does not recognize are returned (and dropped)\n",
    "    # rather than causing an error exit.\n",
    "    model_args, script_args, _unknown = parser.parse_args_into_dataclasses(\n",
    "        return_remaining_strings=True\n",
    "    )\n",
    "    return model_args, script_args\n",
    "\n",
    "def load_training_data(dataset_name: str):\n",
    "    \"\"\"\n",
    "    Fetch the training split of a Hugging Face dataset.\n",
    "\n",
    "    Args:\n",
    "        dataset_name (str): Name or path of the dataset.\n",
    "    Returns:\n",
    "        Dataset object\n",
    "    \"\"\"\n",
    "    train_split = load_dataset(dataset_name, split=\"train\")\n",
    "    return train_split\n",
    "\n",
    "def get_training_args(script_args: ScriptArguments):\n",
    "    \"\"\"\n",
    "    Build TrainingArguments from ScriptArguments.\n",
    "\n",
    "    Forwards every training-related field declared on ScriptArguments —\n",
    "    including num_train_epochs, weight_decay, per_device_eval_batch_size\n",
    "    and gradient_checkpointing, which were previously declared on the\n",
    "    dataclass but never passed through.\n",
    "\n",
    "    Args:\n",
    "        script_args (ScriptArguments): Parsed ScriptArguments.\n",
    "    Returns:\n",
    "        TrainingArguments\n",
    "    \"\"\"\n",
    "    return TrainingArguments(\n",
    "        output_dir=\"./results\",\n",
    "        per_device_train_batch_size=script_args.per_device_train_batch_size,\n",
    "        per_device_eval_batch_size=script_args.per_device_eval_batch_size,\n",
    "        gradient_accumulation_steps=script_args.gradient_accumulation_steps,\n",
    "        optim=script_args.optim,\n",
    "        save_steps=script_args.save_steps,\n",
    "        logging_steps=script_args.logging_steps,\n",
    "        learning_rate=script_args.learning_rate,\n",
    "        weight_decay=script_args.weight_decay,\n",
    "        fp16=script_args.fp16,\n",
    "        bf16=script_args.bf16,\n",
    "        max_grad_norm=script_args.max_grad_norm,\n",
    "        num_train_epochs=script_args.num_train_epochs,\n",
    "        # NOTE: a positive max_steps overrides num_train_epochs in the Trainer.\n",
    "        max_steps=script_args.max_steps,\n",
    "        warmup_ratio=script_args.warmup_ratio,\n",
    "        gradient_checkpointing=script_args.gradient_checkpointing,\n",
    "        group_by_length=script_args.group_by_length,\n",
    "        lr_scheduler_type=script_args.lr_scheduler_type,\n",
    "    )\n",
    "\n",
    "def adjust_model_for_bf16(trainer, bf16: bool):\n",
    "    \"\"\"\n",
    "    Adjust Model Layers for bf16.\n",
    "\n",
    "    Walks every module of the trainer's model and re-casts dtypes for\n",
    "    mixed-precision stability: LoRA layers to bf16 (when enabled), all\n",
    "    *norm* layers to float32, and lm_head/embedding weights to bf16.\n",
    "    nn.Module.to() casts parameters in place and returns self, so the\n",
    "    loop-variable rebinding below still mutates the model itself.\n",
    "\n",
    "    Args:\n",
    "        trainer (SFTTrainer): Initialized SFTTrainer object.\n",
    "        bf16 (bool): Flag to indicate usage of bf16.\n",
    "    \"\"\"\n",
    "    for name, module in trainer.model.named_modules():\n",
    "        # Cast LoRA adapter layers to bf16 when bf16 training is on.\n",
    "        if isinstance(module, LoraLayer) and bf16:\n",
    "            module = module.to(torch.bfloat16)\n",
    "        # Keep normalization layers in full precision for numerical stability.\n",
    "        if \"norm\" in name:\n",
    "            module = module.to(torch.float32)\n",
    "        # Match output head / embeddings to the bf16 compute dtype, but only\n",
    "        # if they are still float32 (avoid touching quantized weights).\n",
    "        if \"lm_head\" in name or \"embed_tokens\" in name:\n",
    "            if hasattr(module, \"weight\") and bf16 and module.weight.dtype == torch.float32:\n",
    "                module = module.to(torch.bfloat16)\n",
    "\n",
    "# Main Execution: parse args, build the quantized model + LoRA config,\n",
    "# load the dataset, and run supervised fine-tuning with trl's SFTTrainer.\n",
    "\n",
    "model_args, script_args = parse_arguments()\n",
    "\n",
    "model, peft_config, tokenizer = get_model_peftconfig_tokenizer(model_args)\n",
    "# The KV cache is incompatible with gradient checkpointing during training.\n",
    "model.config.use_cache = False\n",
    "\n",
    "dataset = load_training_data(script_args.dataset_name)\n",
    "\n",
    "training_arguments = get_training_args(script_args)\n",
    "\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    train_dataset=dataset,\n",
    "    peft_config=peft_config,\n",
    "    # Column of the dataset that holds the raw training text\n",
    "    # (the guanaco dataset exposes a \"text\" field).\n",
    "    dataset_text_field=\"text\",\n",
    "    max_seq_length=script_args.max_seq_length,\n",
    "    tokenizer=tokenizer,\n",
    "    args=training_arguments,\n",
    "    packing=script_args.packing,\n",
    ")\n",
    "\n",
    "# Re-cast layer dtypes (LoRA -> bf16, norms -> fp32) after the trainer\n",
    "# has wrapped the model with its PEFT adapters.\n",
    "adjust_model_for_bf16(trainer, script_args.bf16)\n",
    "\n",
    "# Train the Model\n",
    "trainer.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py10env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
